import pandas as pd
from transformers import pipeline, logging
from sklearn.metrics import matthews_corrcoef, accuracy_score, f1_score, balanced_accuracy_score, precision_recall_fscore_support, classification_report
from datasets import load_dataset, DatasetDict, Dataset, disable_progress_bars
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
import torch
import numpy as np
import time
# Turn off the progress bars of the `datasets` library
disable_progress_bars()
import warnings
# Silence every UserWarning and FutureWarning for the rest of the run.
# NOTE(review): this also hides library deprecation notices.
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
# Disable pandas' chained-assignment warning
pd.options.mode.chained_assignment = None 
# Only let transformers log errors
logging.set_verbosity_error()
# Wall-clock start; the total run time is printed at the end of the script
start_time = time.time()

###########################
## Set device from CLI args
###########################
import os
import argparse
def get_device():
    """Return the requested device name as a string.

    Resolution order: the ``--device`` command-line option, then the
    ``DEVICE`` environment variable, then the fallback ``"cuda"``.
    Unrecognised command-line arguments are ignored so the script can be
    launched with extra options without failing here.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--device", default=None, help="Device to use: cpu, mps, cuda")
    # parse_known_args tolerates options this parser does not declare
    known, _extras = cli.parse_known_args()
    if known.device:
        return known.device
    # No (non-empty) CLI value: use the environment, else fall back to "cuda"
    return os.getenv("DEVICE", "cuda")
# Resolve the requested device name and report it
DEVICE = get_device()
print(f"[INFO] Using DEVICE={DEVICE}")
# PyTorch setup
import torch
# Reject anything other than the three supported names
if DEVICE not in ("cpu", "mps", "cuda"):
    raise ValueError(f"Unknown device {DEVICE}")
# Use the requested accelerator only when torch reports it as available;
# otherwise (or when "cpu" was requested) fall back to the CPU.
device = torch.device(
    DEVICE
    if (DEVICE == "mps" and torch.backends.mps.is_available())
    or (DEVICE == "cuda" and torch.cuda.is_available())
    else "cpu"
)
print(f"[INFO] Torch device set to {device}")

###########################
## Data, functions, etc.
###########################
# Model identifier passed to every `from_pretrained` / `pipeline` call in this script
modname = "mlburnham/Political_DEBATE_DeBERTa_large_v1.1"

# Tokenize the documents
def tokenize_function(docs):
    """Tokenize premise/hypothesis pairs with the module-level ``tokenizer``.

    ``docs`` must provide the keys ``'premise'`` and ``'hypothesis'``.
    Sequences are truncated and padded with ``padding='max_length'``.
    Returns whatever the tokenizer call returns.
    """
    premises = docs['premise']
    hypotheses = docs['hypothesis']
    return tokenizer(
        premises,
        hypotheses,
        padding='max_length',
        truncation=True,
    )

# Function for computing performance during training
def compute_metrics_standard(eval_pred, label_text_alphabetical=None):
    """Compute, print and return classification metrics for one evaluation.

    Args:
        eval_pred: Object exposing ``label_ids`` (gold class ids) and
            ``predictions`` (per-class scores; the predicted class is the
            argmax along axis 1).
        label_text_alphabetical: Class names, where the name at position
            ``i`` describes class id ``i``. Defaults to
            ``['entailment', 'not_entailment']``.

    Returns:
        dict with the keys ``MCC``, ``f1_macro``, ``f1_micro``,
        ``accuracy_balanced``, ``accuracy``, ``precision_macro``,
        ``recall_macro``, ``precision_micro`` and ``recall_micro``.
    """
    # Avoid a mutable default argument; build the default per call instead.
    if label_text_alphabetical is None:
        label_text_alphabetical = ['entailment', 'not_entailment']
    label_names = list(label_text_alphabetical)

    labels = eval_pred.label_ids
    pred_logits = eval_pred.predictions
    preds_max = np.argmax(pred_logits, axis=1)

    # metrics
    precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(labels, preds_max, average='macro')
    precision_micro, recall_micro, f1_micro, _ = precision_recall_fscore_support(labels, preds_max, average='micro')
    acc_balanced = balanced_accuracy_score(labels, preds_max)
    acc_not_balanced = accuracy_score(labels, preds_max)
    mcc = matthews_corrcoef(labels, preds_max)

    metrics = {
        'MCC': mcc,
        'f1_macro': f1_macro,
        'f1_micro': f1_micro,
        'accuracy_balanced': acc_balanced,
        'accuracy': acc_not_balanced,
        'precision_macro': precision_macro,
        'recall_macro': recall_macro,
        'precision_micro': precision_micro,
        'recall_micro': recall_micro,
    }
    print("Aggregate metrics: ", dict(metrics))

    # Class ids are simply 0..n-1. The previous implementation derived the
    # same values through pd.factorize on a plain list, which pandas has
    # deprecated for non-array inputs.
    class_ids = np.arange(len(label_names))
    print("Detailed metrics: ", classification_report(
        labels, preds_max, labels=class_ids,
        target_names=label_names, sample_weight=None,
        digits=2, output_dict=True, zero_division='warn'),
    "\n")

    return metrics

# Import Data
# Load the labelled tweets and keep only the text and the label column.
# `.copy()` makes the subset an independent frame so the column assignments
# below never operate on a view of the original data.
df = pd.read_csv('./data/covid_tweets_labeled.csv')
df = df[['text', 'non_comp']].copy()
# Every tweet is paired with the same hypothesis
df['hypothesis'] = 'The author of this tweet does not believe COVID is dangerous.'
df = df.rename(columns={'text': 'premise', 'non_comp': 'entailment'})
# Flip the coding (0 <-> 1). Assign the result back instead of calling
# `replace(..., inplace=True)` on `df['entailment']`: that chained inplace
# form does not update `df` under pandas Copy-on-Write.
df['entailment'] = df['entailment'].replace({0: 1, 1: 0})

#############
## Few Shot
#############
# Instantiate Model
# This module-level instance is what the training loop reads
# `model.config.id2label` from; every training run gets its own fresh model
# through `model_init`.
model = AutoModelForSequenceClassification.from_pretrained(modname, num_labels=2, ignore_mismatched_sizes=True)
# Load the tokenizer once (it was previously loaded a second time from the
# same checkpoint after `training_args`).
tokenizer = AutoTokenizer.from_pretrained(modname)

# Training configuration shared by every few-shot run: no intermediate
# evaluation, no checkpoints, fixed training seed.
# NOTE(review): the `device` resolved from --device/DEVICE is not passed
# here; confirm the Trainer ends up on the intended device, and that
# bf16=True is supported on it.
training_args = TrainingArguments(
    output_dir='few_shot',
    lr_scheduler_type="linear",
    group_by_length=False,
    learning_rate=9e-6,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=16,
    gradient_accumulation_steps=1,
    num_train_epochs=5,
    warmup_ratio=0.06,
    weight_decay=0.01,
    bf16=True,
    eval_strategy="no",
    seed=1,
    save_strategy="no",
    dataloader_num_workers=0,
    disable_tqdm=True,
)

# Define a function that initializes the model inside the Trainer. This makes results reproducible.
def model_init():
    """Load and return a new two-label classification model from ``modname``."""
    fresh_model = AutoModelForSequenceClassification.from_pretrained(
        modname,
        num_labels=2,
        ignore_mismatched_sizes=True,
    )
    return fresh_model

# Training-set sizes (shots) to evaluate; each size is repeated with the
# sampling seeds 0..19 in the loop below
shots = [10, 25, 50, 100]

# Parallel lists collecting one entry per (shot, seed) run
mcc_list = []
acc_list = []
f1_list = []
shots_list = []

# Label mapping (loop-invariant, so defined once)
id2label = {0: "entailment", 1: "not_entailment"}

# Iterate through different shot sizes
for shot in shots:
    # Iterate through different random seeds
    for seed in range(20):
        print(f"DEBATE Base seed {seed}, {shot} shots")
        # Sample training data based on current shot size and seed
        train = df.sample(shot, random_state=seed)
        # Every row that was not sampled for training is used for validation
        val = df[~df.index.isin(train.index)]

        # Create a DatasetDict with train and validation splits
        ds = DatasetDict({
            'train': Dataset.from_pandas(train, preserve_index=False),
            'validation': Dataset.from_pandas(val, preserve_index=False),
        })
        # Tokenize the dataset
        dstok = ds.map(tokenize_function, batched=True)
        # Rename 'entailment' column to 'label'
        dstok = dstok.rename_columns({'entailment': 'label'})

        # A fresh model is created for each run via model_init
        trainer = Trainer(
            model_init=model_init,
            tokenizer=tokenizer,
            args=training_args,
            train_dataset=dstok['train'],
            eval_dataset=dstok['validation'],
            compute_metrics=lambda x: compute_metrics_standard(x, label_text_alphabetical=list(model.config.id2label.values()))
        )

        # Train the model
        trainer.train()
        # Make predictions on validation set
        res = trainer.predict(dstok['validation'])
        preds = np.argmax(res.predictions, axis=-1)

        # Calculate Matthews Correlation Coefficient
        mcc_res = matthews_corrcoef(val['entailment'], preds)
        mcc_list.append(mcc_res)
        # Calculate Accuracy
        acc_res = accuracy_score(val['entailment'], preds)
        acc_list.append(acc_res)
        # Calculate F1 for class 0 ("entailment"). The zero-shot section
        # scores F1 with pos_label=0 and both values are written to the same
        # 'f1' column, so the positive class must match here.
        f1_res = f1_score(val['entailment'], preds, pos_label=0)
        f1_list.append(f1_res)
        # Store the current shot size
        shots_list.append(shot)

# One row per (shot, seed) run
res_df = pd.DataFrame({'n': shots_list, 'f1': f1_list, 'mcc': mcc_list, 'accuracy': acc_list})

#############
## Zero Shot
#############
# Build the zero-shot pipeline on the resolved torch device
pipe = pipeline(
    "zero-shot-classification",
    model=modname,
    device=device,
    batch_size=32,
)

# Score every lower-cased tweet against the single candidate hypothesis;
# the template only appends the trailing period.
zs_premises = df['premise'].str.lower().tolist()
zs_candidates = ['The author of this tweet does not believe COVID is dangerous']
res = pipe(zs_premises, zs_candidates, hypothesis_template='{}.', multi_label=False)

# Round the first score of each result to 0/1, then flip the coding so it
# matches the 'entailment' column.
labels = [round(item['scores'][0], 0) for item in res]
df['zs_large'] = labels
df['zs_large'] = df['zs_large'].replace({0: 1, 1: 0})

# Single-row frame with the zero-shot metrics (n = 0 training examples)
zs_gold = df['entailment']
zs_pred = df['zs_large']
zs_df = pd.DataFrame(
    {
        'n': 0,
        'mcc': matthews_corrcoef(zs_gold, zs_pred),
        'f1': f1_score(zs_gold, zs_pred, pos_label=0),
        'accuracy': accuracy_score(zs_gold, zs_pred),
    },
    index=[0],
)

# Zero-shot row first, followed by the few-shot rows
res_df = pd.concat([zs_df, res_df], axis=0)

##########
## Export
##########
# Write the combined zero-shot and few-shot results to CSV (without the index)
res_df.to_csv('./data/covid_fewshot_large.csv', index = False)
# Total elapsed wall-clock time in seconds since `start_time`
run_time = time.time() - start_time
print(run_time) # 1.5 hours on 3090, 8 hours on M3